{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#pragma once

#include "caspar_mappings.h"
#include "pybind_array_tools.h"

namespace caspar {

/**
 * Registers the layout-conversion bindings on the given pybind11 module.
 *
 * For every type in caslib.exposed_types two Python-callable functions are added:
 *   - <Type>_stacked_to_caspar(stacked_data, cas_data)
 *   - <Type>_caspar_to_stacked(cas_data, stacked_data)
 *
 * Each binding validates the shapes of both arrays and then forwards the raw float pointers
 * to the conversion routine of the same name. A std::runtime_error is thrown when
 *   - stacked_data does not have exactly storage_dim columns,
 *   - cas_data does not have exactly caspar_size(storage_dim) rows, or
 *   - cas_data has fewer columns than stacked_data has rows.
 *
 * Declared inline because this definition lives in a header; without it, including the header
 * from more than one translation unit would produce duplicate definitions at link time.
 *
 * @param module The pybind11 module that receives the generated functions.
 */
inline void add_casmappings_pybindings(pybind11::module_ module) {
  {% for nodetype in caslib.exposed_types %}
  // Binding for {{nodetype.__name__}}: stacked layout -> caspar layout.
  module.def("{{nodetype.__name__}}_stacked_to_caspar",
      [](pybind11::object stacked_data, pybind11::object cas_data) {
        if (GetNumCols(stacked_data) != {{Ops.storage_dim(nodetype)}}) {
          throw std::runtime_error(
              "The stacked data must have {{Ops.storage_dim(nodetype)}} columns.");
        }
        if (GetNumRows(cas_data) != {{caspar_size(Ops.storage_dim(nodetype))}}) {
          throw std::runtime_error(
              "The caspar data must have {{caspar_size(Ops.storage_dim(nodetype))}} rows.");
        }
        // One row of stacked_data per object; the column count of cas_data is used as stride.
        const int num_objects = GetNumRows(stacked_data);
        const int cas_stride = GetNumCols(cas_data);
        if (cas_stride < num_objects) {
          throw std::runtime_error(
              "The caspar data must have at least as many columns as stacked_data has rows.");
        }
        // NOTE(review): the literal 0 is presumably a start offset into the caspar buffer;
        // confirm against the declaration in caspar_mappings.h.
        {{nodetype.__name__}}_stacked_to_caspar(
            AsFloatPtr(stacked_data), AsFloatPtr(cas_data), cas_stride, 0, num_objects);
      });
  // Binding for {{nodetype.__name__}}: caspar layout -> stacked layout.
  module.def("{{nodetype.__name__}}_caspar_to_stacked",
      [](pybind11::object cas_data, pybind11::object stacked_data) {
        if (GetNumCols(stacked_data) != {{Ops.storage_dim(nodetype)}}) {
          throw std::runtime_error(
              "The stacked data must have {{Ops.storage_dim(nodetype)}} columns.");
        }
        if (GetNumRows(cas_data) != {{caspar_size(Ops.storage_dim(nodetype))}}) {
          throw std::runtime_error(
              "The caspar data must have {{caspar_size(Ops.storage_dim(nodetype))}} rows.");
        }
        // One row of stacked_data per object; the column count of cas_data is used as stride.
        const int num_objects = GetNumRows(stacked_data);
        const int cas_stride = GetNumCols(cas_data);
        if (cas_stride < num_objects) {
          throw std::runtime_error(
              "The caspar data must have at least as many columns as stacked_data has rows.");
        }
        // NOTE(review): the literal 0 is presumably a start offset into the caspar buffer;
        // confirm against the declaration in caspar_mappings.h.
        {{nodetype.__name__}}_caspar_to_stacked(
            AsFloatPtr(cas_data), AsFloatPtr(stacked_data), cas_stride, 0, num_objects);
      });
  {% endfor %}
}

} // namespace caspar
